package com.yeskery.nut.application.netty;

import com.yeskery.nut.util.StringUtils;
import com.yeskery.nut.websocket.Session;
import com.yeskery.nut.websocket.WebSocketConfiguration;
import com.yeskery.nut.websocket.WebSocketConfigurationRegistry;
import com.yeskery.nut.websocket.netty.NettyWebSocketSession;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.util.CharsetUtil;

import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Netty WebSocket服务处理器
 * @author sprout
 * @version 1.0
 * 2023-04-22 14:37
 */
public class NettyWebSocketServerContext implements WebSocketServerContext {

    /** 日志对象 */
    private static final Logger logger = Logger.getLogger(NettyWebSocketServerContext.class.getName());

    /** ws协议 */
    private static final String WS_PROTOCOL = "ws://";

    /** wss协议 */
    private static final String WSS_PROTOCOL = "wss://";

    /** WebSocket配置注册中心 */
    private final WebSocketConfigurationRegistry webSocketConfigurationRegistry;

    /** 是否以安全方式启动 */
    private final boolean secure;

    /**
     * 构建WebSocket服务处理器
     * @param webSocketConfigurationRegistry WebSocket 配置注册中心
     * @param secure 是否以安全方式启动
     */
    public NettyWebSocketServerContext(WebSocketConfigurationRegistry webSocketConfigurationRegistry, boolean secure) {
        this.webSocketConfigurationRegistry = webSocketConfigurationRegistry;
        this.secure = secure;
    }

    @Override
    public boolean isWebSocketRequest(ChannelHandlerContext ctx, FullHttpRequest request){
        QueryStringDecoder decoder = new QueryStringDecoder(request.uri());
        WebSocketConfiguration webSocketConfiguration = webSocketConfigurationRegistry.findWebSocketConfiguration(decoder.rawPath());
        if (webSocketConfiguration == null) {
            return false;
        }
        channelReadWebSocket(ctx, request, decoder, webSocketConfiguration);
        return true;
    }

    /**
     * 处理websocket请求
     * @param ctx 通道处理器上下文
     * @param req 请求对象
     * @param decoder 请求参数解码器
     * @param webSocketConfiguration WebSocket配置对象
     */
    private void channelReadWebSocket(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder,
                                      WebSocketConfiguration webSocketConfiguration) {
        FullHttpResponse res;
        if (!req.decoderResult().isSuccess()) {
            res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST);
            sendWebSocketResponse(ctx, req, res);
            return;
        }

        if (req.method() != HttpMethod.GET) {
            res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
            sendWebSocketResponse(ctx, req, res);
            return;
        }

        String host = req.headers().get(HttpHeaderNames.HOST);
        if (StringUtils.isEmpty(host)) {
            res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
            sendWebSocketResponse(ctx, req, res);
            return;
        }

        if (!req.headers().contains(HttpHeaderValues.UPGRADE) || !req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY)
                || !req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_VERSION)) {
            res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
            sendWebSocketResponse(ctx, req, res);
            return;
        }

        try {
            handleWebSocketRequest(ctx, req, decoder, webSocketConfiguration);
        } catch (Exception e) {
            res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR);
            sendWebSocketResponse(ctx, req, res);
            logger.logp(Level.SEVERE, NettyWebSocketServerContext.class.getName(), "channelReadWebSocket",
                    "Netty WebSocket Execute Fail.", e);
        }
    }

    /**
     * 处理websocket请求
     * @param ctx 通道处理器上下文
     * @param req 请求上下文
     * @param decoder 请求参数解码器
     * @param webSocketConfiguration WebSocket配置对象
     */
    private void handleWebSocketRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder,
                                        WebSocketConfiguration webSocketConfiguration) {
        Channel channel = ctx.channel();
        if (!channel.isActive()) {
            return;
        }

        String fullUrl = secure ? WSS_PROTOCOL : WS_PROTOCOL + req.headers().get(HttpHeaderNames.HOST) + req.uri();
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(fullUrl, null, true);
        WebSocketServerHandshaker handShaker = wsFactory.newHandshaker(req);
        if (handShaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(channel);
        } else {
            ChannelPipeline pipeline = ctx.pipeline();
            pipeline.remove(ctx.name());
            pipeline.addLast(new WebSocketServerCompressionHandler());
            // 管道添加WebSocketServerHandler
            pipeline.addLast(new NettyWebSocketHandler(decoder, webSocketConfiguration));
            handShaker.handshake(channel, req).addListener(future -> {
                if (future.isSuccess()) {
                    Session session = new NettyWebSocketSession(ctx, decoder, webSocketConfiguration.getWebSocketServerConfigure());
                    webSocketConfiguration.getWebSocketHandler().onOpen(session);
                } else {
                    handShaker.close(channel, new CloseWebSocketFrame());
                }
            });
        }
    }

    /**
     * 发送WebSocket响应
     * @param ctx 通道处理器上下文
     * @param req 请求对象
     * @param res 响应对象
     */
    private void sendWebSocketResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
        int statusCode = res.status().code();
        if (statusCode != HttpResponseStatus.OK.code() && res.content().readableBytes() == 0) {
            ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
            res.content().writeBytes(buf);
            buf.release();
        }
        HttpUtil.setContentLength(res, res.content().readableBytes());

        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!HttpUtil.isKeepAlive(req) || statusCode != HttpResponseStatus.OK.code()) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

}
